{
  "cells": [
    {
      "cell_type": "code",
      "execution_count": 0,
      "metadata": {
        "cellView": "form",
        "colab": {},
        "colab_type": "code",
        "id": "gnxCiVBlUUwm"
      },
      "outputs": [],
      "source": [
        "#@title ##### License\n",
        "# Copyright 2018 The GraphNets Authors. All Rights Reserved.\n",
        "#\n",
        "# Licensed under the Apache License, Version 2.0 (the \"License\");\n",
        "# you may not use this file except in compliance with the License.\n",
        "# You may obtain a copy of the License at\n",
        "#\n",
        "#    http://www.apache.org/licenses/LICENSE-2.0\n",
        "#\n",
        "# Unless required by applicable law or agreed to in writing, software\n",
        "# distributed under the License is distributed on an \"AS IS\" BASIS,\n",
        "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or  implied.\n",
        "# See the License for the specific language governing permissions and\n",
        "# limitations under the License.\n",
        "# ============================================================================"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "tRbPqRK12xp7"
      },
      "source": [
        "# Sort a list of elements\n",
        "\n",
        "This notebook and the accompanying code demonstrates how to use the Graph Nets library to\n",
        "learn to sort a list of elements.\n",
        "\n",
        "The list of elements is treated as a fully connected graph, and the task is to\n",
        "label the nodes and edges as a linked list. The network is trained to label the\n",
        "start node, and which (directed) edges correspond to the links to the next\n",
        "largest element, for each node.\n",
        "\n",
        "After training, the network's prediction ability is illustrated by comparing its\n",
        "output to the true sorted list. Then the network's ability to generalise is\n",
        "tested, by using it to sort larger lists."
      ]
    },
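    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "As a small worked example of this encoding (illustrative only): for the list\n",
        "`[0.3, 0.1, 0.2]`, the start node is node 1 (the smallest value), and the\n",
        "\"next largest\" links are the directed edges `1 -\u003e 2` and `2 -\u003e 0`; every\n",
        "other edge of the fully connected graph is labeled `False`."
      ]
    },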
    {
      "cell_type": "code",
      "execution_count": 0,
      "metadata": {
        "cellView": "form",
        "colab": {},
        "colab_type": "code",
        "id": "OkRrv8l1Uhza"
      },
      "outputs": [],
      "source": [
        "#@title ### Install the Graph Nets library on this Colaboratory runtime  { form-width: \"60%\", run: \"auto\"}\n",
        "#@markdown \u003cbr\u003e1. Connect to a local or hosted Colaboratory runtime by clicking the **Connect** button at the top-right.\u003cbr\u003e2. Choose \"Yes\" below to install the Graph Nets library on the runtime machine with the correct dependencies. Note, this works both with local and hosted Colaboratory runtimes.\n",
        "\n",
        "install_graph_nets_library = \"No\"  #@param [\"Yes\", \"No\"]\n",
        "\n",
        "if install_graph_nets_library.lower() == \"yes\":\n",
        "  print(\"Installing Graph Nets library and dependencies:\")\n",
        "  print(\"Output message from command:\\n\")\n",
        "  !pip install \"graph_nets\u003e=1.1\" \"dm-sonnet\u003e=2.0.0b0\"\n",
        "else:\n",
        "  print(\"Skipping installation of Graph Nets library\")"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "goK_7YhSVBZ-"
      },
      "source": [
        "### Install dependencies locally\n",
        "\n",
        "If you are running this notebook locally (i.e., not through Colaboratory), you will also need to install a few more dependencies. Run the following on the command line to install the graph networks library, as well as a few other dependencies:\n",
        "\n",
        "```\n",
        "pip install \"graph_nets\u003e=1.1\" \"tensorflow\u003e=2.1.0-rc1\" \"dm-sonnet\u003e=2.0.0b0\" tensorflow_probability\n",
        "```"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "OD_uoUHdGp5t"
      },
      "source": [
        "# Code"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 0,
      "metadata": {
        "cellView": "form",
        "colab": {},
        "colab_type": "code",
        "id": "8moWllwb-yZr"
      },
      "outputs": [],
      "source": [
        "#@title Imports  { form-width: \"30%\" }\n",
        "%tensorflow_version 2.x  # For Google Colab only.\n",
        "\n",
        "from __future__ import absolute_import\n",
        "from __future__ import division\n",
        "from __future__ import print_function\n",
        "\n",
        "import time\n",
        "\n",
        "\n",
        "from graph_nets import utils_np\n",
        "from graph_nets import utils_tf\n",
        "from graph_nets.demos_tf2 import models\n",
        "import matplotlib.pyplot as plt\n",
        "import numpy as np\n",
        "\n",
        "import sonnet as snt\n",
        "import tensorflow as tf\n",
        "\n",
        "\n",
        "\n",
        "SEED = 1\n",
        "np.random.seed(SEED)\n",
        "tf.random.set_seed(SEED)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 0,
      "metadata": {
        "cellView": "form",
        "colab": {},
        "colab_type": "code",
        "id": "MkwjeiSdLMrf"
      },
      "outputs": [],
      "source": [
        "#@title Helper functions  { form-width: \"30%\" }\n",
        "\n",
        "# pylint: disable=redefined-outer-name\n",
        "\n",
        "def create_graph_dicts_tf(num_examples, num_elements_min_max):\n",
        "  \"\"\"Generate graphs for training.\n",
        "\n",
        "  Args:\n",
        "    num_examples: total number of graphs to generate\n",
        "    num_elements_min_max: a 2-tuple with the minimum and maximum number of\n",
        "      values allowable in a graph. The number of values for a graph is\n",
        "      uniformly sampled withing this range. The upper bound is exclusive, and\n",
        "      should be at least 2 more than the lower bound.\n",
        "\n",
        "  Returns:\n",
        "    inputs: contains the generated random numbers as node values.\n",
        "    sort_indices: contains the sorting indices as nodes. Concretely\n",
        "      inputs.nodes[sort_indices.nodes] will be a sorted array.\n",
        "    ranks: the rank of each value in inputs normalized to the range [0, 1].\n",
        "  \"\"\"\n",
        "  num_elements = tf.random.uniform(\n",
        "      [num_examples],\n",
        "      minval=num_elements_min_max[0],\n",
        "      maxval=num_elements_min_max[1],\n",
        "      dtype=tf.int32)\n",
        "  inputs_graphs = []\n",
        "  sort_indices_graphs = []\n",
        "  ranks_graphs = []\n",
        "  for i in range(num_examples):\n",
        "    values = tf.random.uniform(shape=[num_elements[i]])\n",
        "    sort_indices = tf.cast(\n",
        "        tf.argsort(values, axis=-1), tf.float32)\n",
        "    ranks = tf.cast(\n",
        "        tf.argsort(sort_indices, axis=-1), tf.float32) / (\n",
        "            tf.cast(num_elements[i], tf.float32) - 1.0)\n",
        "    inputs_graphs.append({\"nodes\": values[:, None]})\n",
        "    sort_indices_graphs.append({\"nodes\": sort_indices[:, None]})\n",
        "    ranks_graphs.append({\"nodes\": ranks[:, None]})\n",
        "  return inputs_graphs, sort_indices_graphs, ranks_graphs\n",
        "\n",
        "\n",
        "def create_linked_list_target(batch_size, input_graphs):\n",
        "  \"\"\"Creates linked list targets.\n",
        "\n",
        "  Returns a graph with the same number of nodes as `input_graph`. Each node\n",
        "  contains a 2d vector with targets for a 1-class classification where only one\n",
        "  node is `True`, the smallest value in the array. The vector contains two\n",
        "  values: [prob_true, prob_false].\n",
        "  It also contains edges connecting all nodes. These are again 2d vectors with\n",
        "  softmax targets [prob_true, prob_false]. An edge is True\n",
        "  if n+1 is the element immediately after n in the sorted list.\n",
        "\n",
        "  Args:\n",
        "    batch_size: batch size for the `input_graphs`.\n",
        "    input_graphs: a `graphs.GraphsTuple` which contains a batch of inputs.\n",
        "\n",
        "  Returns:\n",
        "    A `graphs.GraphsTuple` with the targets, which encode the linked list.\n",
        "  \"\"\"\n",
        "  target_graphs = []\n",
        "  for i in range(batch_size):\n",
        "    input_graph = utils_tf.get_graph(input_graphs, i)\n",
        "    num_elements = tf.shape(input_graph.nodes)[0]\n",
        "    si = tf.cast(tf.squeeze(input_graph.nodes), tf.int32)\n",
        "    nodes = tf.reshape(tf.one_hot(si[:1], num_elements), (-1, 1))\n",
        "    x = tf.stack((si[:-1], si[1:]))[None]\n",
        "    y = tf.stack(\n",
        "        (input_graph.senders, input_graph.receivers), axis=1)[:, :, None]\n",
        "    edges = tf.reshape(\n",
        "        tf.cast(\n",
        "            tf.reduce_any(tf.reduce_all(tf.equal(x, y), axis=1), axis=1),\n",
        "            tf.float32), (-1, 1))\n",
        "    target_graphs.append(input_graph._replace(nodes=nodes, edges=edges))\n",
        "  return utils_tf.concat(target_graphs, axis=0)\n",
        "\n",
        "\n",
        "def compute_accuracy(target, output):\n",
        "  \"\"\"Calculate model accuracy.\n",
        "\n",
        "  Returns the number of correctly predicted links and the number\n",
        "  of completely solved list sorts (100% correct predictions).\n",
        "\n",
        "  Args:\n",
        "    target: A `graphs.GraphsTuple` that contains the target graph.\n",
        "    output: A `graphs.GraphsTuple` that contains the output graph.\n",
        "\n",
        "  Returns:\n",
        "    correct: A `float` fraction of correctly labeled nodes/edges.\n",
        "    solved: A `float` fraction of graphs that are completely correctly labeled.\n",
        "  \"\"\"\n",
        "  tdds = utils_np.graphs_tuple_to_data_dicts(target)\n",
        "  odds = utils_np.graphs_tuple_to_data_dicts(output)\n",
        "  cs = []\n",
        "  ss = []\n",
        "  for td, od in zip(tdds, odds):\n",
        "    num_elements = td[\"nodes\"].shape[0]\n",
        "    xn = np.argmax(td[\"nodes\"], axis=-1)\n",
        "    yn = np.argmax(od[\"nodes\"], axis=-1)\n",
        "\n",
        "    xe = np.reshape(\n",
        "        np.argmax(\n",
        "            np.reshape(td[\"edges\"], (num_elements, num_elements, 2)), axis=-1),\n",
        "        (-1,))\n",
        "    ye = np.reshape(\n",
        "        np.argmax(\n",
        "            np.reshape(od[\"edges\"], (num_elements, num_elements, 2)), axis=-1),\n",
        "        (-1,))\n",
        "    c = np.concatenate((xn == yn, xe == ye), axis=0)\n",
        "    s = np.all(c)\n",
        "    cs.append(c)\n",
        "    ss.append(s)\n",
        "  correct = np.mean(np.concatenate(cs, axis=0))\n",
        "  solved = np.mean(np.stack(ss))\n",
        "  return correct, solved\n",
        "\n",
        "\n",
        "def create_data(batch_size, num_elements_min_max):\n",
        "  \"\"\"Returns graphs containing the inputs and targets for classification.\n",
        "\n",
        "  Refer to create_data_dicts_tf and create_linked_list_target for more details.\n",
        "\n",
        "  Args:\n",
        "    batch_size: batch size for the `input_graphs`.\n",
        "    num_elements_min_max: a 2-`tuple` of `int`s which define the [lower, upper)\n",
        "      range of the number of elements per list.\n",
        "\n",
        "  Returns:\n",
        "    inputs: a `graphs.GraphsTuple` which contains the input list as a graph.\n",
        "    targets: a `graphs.GraphsTuple` which contains the target as a graph.\n",
        "    sort_indices: a `graphs.GraphsTuple` which contains the sort indices of\n",
        "      the list elements a graph.\n",
        "    ranks: a `graphs.GraphsTuple` which contains the ranks of the list\n",
        "      elements as a graph.\n",
        "  \"\"\"\n",
        "  inputs, sort_indices, ranks = create_graph_dicts_tf(\n",
        "      batch_size, num_elements_min_max)\n",
        "  inputs = utils_tf.data_dicts_to_graphs_tuple(inputs)\n",
        "  sort_indices = utils_tf.data_dicts_to_graphs_tuple(sort_indices)\n",
        "  ranks = utils_tf.data_dicts_to_graphs_tuple(ranks)\n",
        "\n",
        "  inputs = utils_tf.fully_connect_graph_dynamic(inputs)\n",
        "  sort_indices = utils_tf.fully_connect_graph_dynamic(sort_indices)\n",
        "  ranks = utils_tf.fully_connect_graph_dynamic(ranks)\n",
        "\n",
        "  targets = create_linked_list_target(batch_size, sort_indices)\n",
        "  nodes = tf.concat((targets.nodes, 1.0 - targets.nodes), axis=1)\n",
        "  edges = tf.concat((targets.edges, 1.0 - targets.edges), axis=1)\n",
        "  targets = targets._replace(nodes=nodes, edges=edges)\n",
        "\n",
        "  return inputs, targets, sort_indices, ranks\n",
        "\n",
        "\n",
        "def create_loss(target, outputs):\n",
        "  \"\"\"Returns graphs containing the inputs and targets for classification.\n",
        "\n",
        "  Refer to create_data_dicts_tf and create_linked_list_target for more details.\n",
        "\n",
        "  Args:\n",
        "    target: a `graphs.GraphsTuple` which contains the target as a graph.\n",
        "    outputs: a `list` of `graphs.GraphsTuple`s which contains the model\n",
        "      outputs for each processing step as graphs.\n",
        "\n",
        "  Returns:\n",
        "    A `list` of ops which are the loss for each processing step.\n",
        "  \"\"\"\n",
        "  # if not isinstance(outputs, collections.Sequence):\n",
        "  #   outputs = [outputs]\n",
        "  losss = [\n",
        "      tf.compat.v1.losses.softmax_cross_entropy(target.nodes, output.nodes) +\n",
        "      tf.compat.v1.losses.softmax_cross_entropy(target.edges, output.edges)\n",
        "      for output in outputs\n",
        "  ]\n",
        "  return tf.stack(losss)\n",
        "\n",
        "\n",
        "\n",
        "def plot_linked_list(ax, graph, sort_indices):\n",
        "  \"\"\"Plot a networkx graph containing weights for the linked list probability.\"\"\"\n",
        "  nd = len(graph.nodes())\n",
        "  probs = np.zeros((nd, nd))\n",
        "  for edge in graph.edges(data=True):\n",
        "    probs[edge[0], edge[1]] = edge[2][\"features\"][0]\n",
        "  ax.matshow(probs[sort_indices][:, sort_indices], cmap=\"viridis\")\n",
        "  ax.grid(False)\n",
        "\n",
        "\n",
        "# pylint: enable=redefined-outer-name"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 0,
      "metadata": {
        "cellView": "form",
        "colab": {},
        "colab_type": "code",
        "id": "NRlAUqM5snzr"
      },
      "outputs": [],
      "source": [
        "#@title Visualize the sort task  { form-width: \"30%\" }\n",
        "num_elements_min_max = (5, 10)\n",
        "\n",
        "inputs, targets, sort_indices, ranks = create_data(\n",
        "    1, num_elements_min_max)\n",
        "\n",
        "inputs_nodes = inputs.nodes.numpy()\n",
        "targets = utils_tf.nest_to_numpy(targets)\n",
        "sort_indices_nodes = sort_indices.nodes.numpy()\n",
        "ranks_nodes = ranks.nodes.numpy()\n",
        "\n",
        "sort_indices = np.squeeze(sort_indices_nodes).astype(int)\n",
        "\n",
        "# Plot sort linked lists.\n",
        "# The matrix plots show each element from the sorted list (rows), and which\n",
        "# element they link to as next largest (columns). Ground truth is a diagonal\n",
        "# offset toward the upper-right by one.\n",
        "fig = plt.figure(1, figsize=(4, 4))\n",
        "fig.clf()\n",
        "ax = fig.add_subplot(1, 1, 1)\n",
        "plot_linked_list(ax,\n",
        "                 utils_np.graphs_tuple_to_networkxs(targets)[0], sort_indices)\n",
        "ax.set_title(\"Element-to-element links for sorted elements\")\n",
        "ax.set_axis_off()\n",
        "\n",
        "fig = plt.figure(2, figsize=(10, 2))\n",
        "fig.clf()\n",
        "ax1 = fig.add_subplot(1, 3, 1)\n",
        "ax2 = fig.add_subplot(1, 3, 2)\n",
        "ax3 = fig.add_subplot(1, 3, 3)\n",
        "\n",
        "i = 0\n",
        "num_elements = ranks_nodes.shape[0]\n",
        "inputs = np.squeeze(inputs_nodes)\n",
        "ranks = np.squeeze(ranks_nodes * (num_elements - 1.0)).astype(int)\n",
        "x = np.arange(inputs.shape[0])\n",
        "\n",
        "ax1.set_title(\"Inputs\")\n",
        "ax1.barh(x, inputs, color=\"b\")\n",
        "ax1.set_xlim(-0.01, 1.01)\n",
        "\n",
        "ax2.set_title(\"Sorted\")\n",
        "ax2.barh(x, inputs[sort_indices], color=\"k\")\n",
        "ax2.set_xlim(-0.01, 1.01)\n",
        "\n",
        "ax3.set_title(\"Ranks\")\n",
        "ax3.barh(x, ranks, color=\"r\")\n",
        "_ = ax3.set_xlim(0, len(ranks) + 0.5)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 0,
      "metadata": {
        "cellView": "form",
        "colab": {},
        "colab_type": "code",
        "id": "hKdqAev4ERPw"
      },
      "outputs": [],
      "source": [
        "#@title Set up model training and evaluation  { form-width: \"30%\" }\n",
        "\n",
        "# The model we explore includes three components:\n",
        "# - An \"Encoder\" graph net, which independently encodes the edge, node, and\n",
        "#   global attributes (does not compute relations etc.).\n",
        "# - A \"Core\" graph net, which performs N rounds of processing (message-passing)\n",
        "#   steps. The input to the Core is the concatenation of the Encoder's output\n",
        "#   and the previous output of the Core (labeled \"Hidden(t)\" below, where \"t\" is\n",
        "#   the processing step).\n",
        "# - A \"Decoder\" graph net, which independently decodes the edge, node, and\n",
        "#   global attributes (does not compute relations etc.), on each\n",
        "#   message-passing step.\n",
        "#\n",
        "#                     Hidden(t)   Hidden(t+1)\n",
        "#                        |            ^\n",
        "#           *---------*  |  *------*  |  *---------*\n",
        "#           |         |  |  |      |  |  |         |\n",
        "# Input ---\u003e| Encoder |  *-\u003e| Core |--*-\u003e| Decoder |---\u003e Output(t)\n",
        "#           |         |----\u003e|      |     |         |\n",
        "#           *---------*     *------*     *---------*\n",
        "#\n",
        "# The model is trained by supervised learning. Input graphs are procedurally\n",
        "# generated, and output graphs have the same structure with the nodes and edges\n",
        "# of the linked list labeled (using 2-element 1-hot vectors). The target\n",
        "# labels the node corresponding to the lowest value in the list, and labels each\n",
        "# which represents the connection between neighboring values in the sorted\n",
        "# list.\n",
        "#\n",
        "# The training loss is computed on the output of each processing step. The\n",
        "# reason for this is to encourage the model to try to solve the problem in as\n",
        "# few steps as possible. It also helps make the output of intermediate steps\n",
        "# more interpretable.\n",
        "#\n",
        "# There's no need for a separate evaluate dataset because the inputs are\n",
        "# never repeated, so the training loss is the measure of performance on graphs\n",
        "# from the input distribution.\n",
        "#\n",
        "# We also evaluate how well the models generalize to lists which are up to\n",
        "# twice as large as those on which it was trained. The loss is computed only\n",
        "# on the final processing step.\n",
        "#\n",
        "# Variables with the suffix _tr are training parameters, and variables with the\n",
        "# suffix _ge are test/generalization parameters.\n",
        "#\n",
        "# After around 2000-5000 training iterations the model reaches near-perfect\n",
        "# performance on lists with between 8-16 elements.\n",
        "\n",
        "# Model parameters.\n",
        "# Number of processing (message-passing) steps.\n",
        "num_processing_steps_tr = 10\n",
        "num_processing_steps_ge = 10\n",
        "\n",
        "# Data / training parameters.\n",
        "num_training_iterations = 3000\n",
        "batch_size_tr = 32\n",
        "batch_size_ge = 100\n",
        "# Number of elements in each list is sampled uniformly from this range.\n",
        "num_elements_min_max_tr = (8, 17)\n",
        "num_elements_min_max_ge = (16, 33)\n",
        "\n",
        "# Data.\n",
        "@tf.function\n",
        "def get_data():\n",
        "  inputs_tr, targets_tr, sort_indices_tr, _ = create_data(\n",
        "      batch_size_tr, num_elements_min_max_tr)\n",
        "  inputs_tr = utils_tf.set_zero_edge_features(inputs_tr, 1)\n",
        "  inputs_tr = utils_tf.set_zero_global_features(inputs_tr, 1)\n",
        "  # Test/generalization.\n",
        "  inputs_ge, targets_ge, sort_indices_ge, _ = create_data(\n",
        "      batch_size_ge, num_elements_min_max_ge)\n",
        "  inputs_ge = utils_tf.set_zero_edge_features(inputs_ge, 1)\n",
        "  inputs_ge = utils_tf.set_zero_global_features(inputs_ge, 1)\n",
        "\n",
        "  targets_tr = utils_tf.set_zero_global_features(targets_tr, 1)\n",
        "  targets_ge = utils_tf.set_zero_global_features(targets_ge, 1)\n",
        "\n",
        "  return inputs_tr, targets_tr, sort_indices_tr, inputs_ge, targets_ge, sort_indices_ge\n",
        "\n",
        "# Optimizer.\n",
        "learning_rate = 1e-3\n",
        "optimizer = snt.optimizers.Adam(learning_rate)\n",
        "\n",
        "model = models.EncodeProcessDecode(edge_output_size=2, node_output_size=2)\n",
        "last_iteration = 0\n",
        "logged_iterations = []\n",
        "losses_tr = []\n",
        "corrects_tr = []\n",
        "solveds_tr = []\n",
        "losses_ge = []\n",
        "corrects_ge = []\n",
        "solveds_ge = []\n",
        "\n",
        "\n",
        "# Training.\n",
        "def update_step(inputs_tr, targets_tr):\n",
        "  with tf.GradientTape() as tape:\n",
        "    outputs_tr = model(inputs_tr, num_processing_steps_tr)\n",
        "    # Loss.\n",
        "    loss_tr = create_loss(targets_tr, outputs_tr)\n",
        "    loss_tr = tf.math.reduce_sum(loss_tr) / num_processing_steps_tr\n",
        "\n",
        "  gradients = tape.gradient(loss_tr, model.trainable_variables)\n",
        "  optimizer.apply(gradients, model.trainable_variables)\n",
        "  return outputs_tr, loss_tr\n"
      ]
    },
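    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "The diagram above corresponds to a simple recurrent loop over the three\n",
        "components. Below is a minimal sketch of that loop (illustrative pseudocode,\n",
        "not the library's implementation; `encoder`, `core`, `decoder` and\n",
        "`output_transform` stand in for the graph net modules that\n",
        "`models.EncodeProcessDecode` builds internally):\n",
        "\n",
        "```python\n",
        "def encode_process_decode(input_graphs, num_processing_steps):\n",
        "  latent0 = encoder(input_graphs)  # Independent edge/node/global encoding.\n",
        "  latent = latent0\n",
        "  outputs = []\n",
        "  for _ in range(num_processing_steps):\n",
        "    # Concatenate the encoder output with the previous core output\n",
        "    # (\"Hidden(t)\" in the diagram) along the feature axis.\n",
        "    core_input = utils_tf.concat([latent0, latent], axis=1)\n",
        "    latent = core(core_input)  # One message-passing step.\n",
        "    outputs.append(output_transform(decoder(latent)))\n",
        "  return outputs  # One output graph per processing step.\n",
        "```"
      ]
    },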
    {
      "cell_type": "code",
      "execution_count": 0,
      "metadata": {
        "cellView": "form",
        "colab": {},
        "colab_type": "code",
        "id": "_1fCATarkOdH"
      },
      "outputs": [],
      "source": [
        "#@title Compiling update_step with `tf_function`  { form-width: \"30%\" }\n",
        "\"\"\"To recover the speed of TensorFlow 1, we need to use `tf.function` to\n",
        "compile the update_step into a graph. However, using `tf.function` naively\n",
        "may futher reduce performance as the number of nodes/edges in the batch can\n",
        "change across batches, tensorflow will repeatedly trace the function multiple\n",
        "times for each unique shape of the input tensors.\n",
        "Instead, we obtain an explicit signature for each input argument to\n",
        "`update_step` using `utils_tf.specs_from_graphs_tuple` that sets `None` sizes\n",
        "for variable length axes in graph fields, preventing `tf.function` from having\n",
        "to trace the function repeatedly.\"\"\"\n",
        "\n",
        "# Get some example data that resembles the tensors that will be fed\n",
        "# into update_step():\n",
        "example_input_data, example_target_data = get_data()[:2]\n",
        "\n",
        "# Get the input signature for that function by obtaining the specs\n",
        "input_signature = [\n",
        "  utils_tf.specs_from_graphs_tuple(example_input_data),\n",
        "  utils_tf.specs_from_graphs_tuple(example_target_data)\n",
        "]\n",
        "\n",
        "# Compile the update function using the input signature for speedy code.\n",
        "compiled_update_step = tf.function(update_step, input_signature=input_signature)"
      ]
    },
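    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "To make the retracing behaviour concrete, here is a minimal, standalone\n",
        "example (separate from the graph-specific specs above): a `None` axis in the\n",
        "`input_signature` lets a single trace serve inputs of any length along that\n",
        "axis.\n",
        "\n",
        "```python\n",
        "import tensorflow as tf\n",
        "\n",
        "@tf.function(input_signature=[tf.TensorSpec([None, 1], tf.float32)])\n",
        "def double(x):\n",
        "  return 2.0 * x\n",
        "\n",
        "double(tf.ones([3, 1]))  # Traced once here...\n",
        "double(tf.ones([7, 1]))  # ...and reused here, despite the new shape.\n",
        "```"
      ]
    },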
    {
      "cell_type": "code",
      "execution_count": 0,
      "metadata": {
        "cellView": "form",
        "colab": {},
        "colab_type": "code",
        "id": "RXQFg5oLyj2Q"
      },
      "outputs": [],
      "source": [
        "#@title Run training steps  { form-width: \"30%\" }\n",
        "\n",
        "# You can interrupt this cell's training loop at any time, and visualize the\n",
        "# intermediate results by running the next cell (below). You can then resume\n",
        "# training by simply executing this cell again.\n",
        "\n",
        "# Instantiate the model.\n",
        "\n",
        "# How much time between logging and printing the current results.\n",
        "log_every_seconds = 20\n",
        "\n",
        "print(\"# (iteration number), T (elapsed seconds), \"\n",
        "      \"Ltr (training loss), Lge (test/generalization loss), \"\n",
        "      \"Ctr (training fraction nodes/edges labeled correctly), \"\n",
        "      \"Str (training fraction examples solved correctly), \"\n",
        "      \"Cge (test/generalization fraction nodes/edges labeled correctly), \"\n",
        "      \"Sge (test/generalization fraction examples solved correctly)\")\n",
        "\n",
        "start_time = time.time()\n",
        "last_log_time = start_time\n",
        "for iteration in range(last_iteration, num_training_iterations):\n",
        "  last_iteration = iteration\n",
        "  (inputs_tr, targets_tr, sort_indices_tr,\n",
        "   inputs_ge, targets_ge, sort_indices_ge) = get_data()\n",
        "\n",
        "  outputs_tr, loss_tr = compiled_update_step(inputs_tr, targets_tr)\n",
        "\n",
        "  the_time = time.time()\n",
        "  elapsed_since_last_log = the_time - last_log_time\n",
        "  if elapsed_since_last_log \u003e log_every_seconds:\n",
        "    last_log_time = the_time\n",
        "    outputs_ge = model(inputs_ge, num_processing_steps_ge)\n",
        "    losss_ge = create_loss(targets_ge, outputs_ge)\n",
        "    loss_ge = losss_ge[-1]\n",
        "\n",
        "    # Replace the globals again to prevent exceptions.\n",
        "    outputs_tr[-1] = outputs_tr[-1].replace(globals=None)\n",
        "    targets_tr = targets_tr.replace(globals=None)\n",
        "\n",
        "    correct_tr, solved_tr = compute_accuracy(\n",
        "        utils_tf.nest_to_numpy(targets_tr),\n",
        "        utils_tf.nest_to_numpy(outputs_tr[-1]))\n",
        "    correct_ge, solved_ge = compute_accuracy(\n",
        "        utils_tf.nest_to_numpy(targets_ge),\n",
        "        utils_tf.nest_to_numpy(outputs_ge[-1]))\n",
        "    elapsed = time.time() - start_time\n",
        "    losses_tr.append(loss_tr.numpy())\n",
        "    corrects_tr.append(correct_tr)\n",
        "    solveds_tr.append(solved_tr)\n",
        "    losses_ge.append(loss_ge.numpy())\n",
        "    corrects_ge.append(correct_ge)\n",
        "    solveds_ge.append(solved_ge)\n",
        "    logged_iterations.append(iteration)\n",
        "    print(\"# {:05d}, T {:.1f}, Ltr {:.4f}, Lge {:.4f}, Ctr {:.4f}, \"\n",
        "          \"Str {:.4f}, Cge {:.4f}, Sge {:.4f}\".format(\n",
        "              iteration, elapsed, loss_tr.numpy(), loss_ge.numpy(),\n",
        "              correct_tr, solved_tr, correct_ge, solved_ge))"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 0,
      "metadata": {
        "cellView": "form",
        "colab": {},
        "colab_type": "code",
        "id": "ZdnzxC1R6cc3"
      },
      "outputs": [],
      "source": [
        "#@title Visualize results  { form-width: \"30%\" }\n",
        "\n",
        "# This cell visualizes the results of training. You can visualize the\n",
        "# intermediate results by interrupting execution of the cell above, and running\n",
        "# this cell. You can then resume training by simply executing the above cell\n",
        "# again.\n",
        "\n",
        "# Plot results curves.\n",
        "fig = plt.figure(11, figsize=(18, 3))\n",
        "fig.clf()\n",
        "x = np.array(logged_iterations)\n",
        "# Loss.\n",
        "y_tr = losses_tr\n",
        "y_ge = losses_ge\n",
        "ax = fig.add_subplot(1, 3, 1)\n",
        "ax.plot(x, y_tr, \"k\", label=\"Training\")\n",
        "ax.plot(x, y_ge, \"k--\", label=\"Test/generalization\")\n",
        "ax.set_title(\"Loss across training\")\n",
        "ax.set_xlabel(\"Training iteration\")\n",
        "ax.set_ylabel(\"Loss (binary cross-entropy)\")\n",
        "ax.legend()\n",
        "# Correct.\n",
        "y_tr = corrects_tr\n",
        "y_ge = corrects_ge\n",
        "ax = fig.add_subplot(1, 3, 2)\n",
        "ax.plot(x, y_tr, \"k\", label=\"Training\")\n",
        "ax.plot(x, y_ge, \"k--\", label=\"Test/generalization\")\n",
        "ax.set_title(\"Fraction correct across training\")\n",
        "ax.set_xlabel(\"Training iteration\")\n",
        "ax.set_ylabel(\"Fraction nodes/edges correct\")\n",
        "# Solved.\n",
        "y_tr = solveds_tr\n",
        "y_ge = solveds_ge\n",
        "ax = fig.add_subplot(1, 3, 3)\n",
        "ax.plot(x, y_tr, \"k\", label=\"Training\")\n",
        "ax.plot(x, y_ge, \"k--\", label=\"Test/generalization\")\n",
        "ax.set_title(\"Fraction solved across training\")\n",
        "ax.set_xlabel(\"Training iteration\")\n",
        "ax.set_ylabel(\"Fraction examples solved\")\n",
        "\n",
        "# Plot sort linked lists for test/generalization.\n",
        "# The matrix plots show each element from the sorted list (rows), and which\n",
        "# element they link to as next largest (columns). Ground truth is a diagonal\n",
        "# offset toward the upper-right by one.\n",
        "outputs = utils_np.graphs_tuple_to_networkxs(outputs_tr[-1])\n",
        "targets = utils_np.graphs_tuple_to_networkxs(targets_tr)\n",
        "inputs = utils_np.graphs_tuple_to_networkxs(inputs_tr)\n",
        "batch_element = 0\n",
        "fig = plt.figure(12, figsize=(8, 4.5))\n",
        "fig.clf()\n",
        "ax1 = fig.add_subplot(1, 2, 1)\n",
        "ax2 = fig.add_subplot(1, 2, 2)\n",
        "sort_indices = np.squeeze(\n",
        "    utils_np.get_graph(sort_indices_tr,\n",
        "                       batch_element).nodes).astype(int)\n",
        "fig.suptitle(\"Element-to-element link predictions for sorted elements\")\n",
        "plot_linked_list(ax1, targets[batch_element], sort_indices)\n",
        "ax1.set_title(\"Ground truth\")\n",
        "ax1.set_axis_off()\n",
        "plot_linked_list(ax2, outputs[batch_element], sort_indices)\n",
        "ax2.set_title(\"Predicted\")\n",
        "ax2.set_axis_off()"
      ]
    }
  ],
  "metadata": {
    "colab": {
      "collapsed_sections": [],
      "last_runtime": {
        "build_target": "//learning/deepmind/dm_python:dm_notebook3",
        "kind": "private"
      },
      "name": "sort.ipynb",
      "provenance": [],
      "toc_visible": true
    },
    "kernelspec": {
      "display_name": "Python 3",
      "name": "python3"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}
